For part 5, I used the layers of ResNet18 up to and including layer3 as the backbone and then tried quite a few different architectures for the regression network. The custom ImageRegModel described below is set up to run the backbone on both images, concatenate the backbone outputs, and then pass the result through the regression part of the network to produce an output of size 2: the predicted x_shift and y_shift. For the regression part of the network I tried several different layers and settings and got the best performance with the layers below. I added the dropout to increase regularization because my models tended to overfit on the training set. I also used a few suggestions from https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html, such as disabling the bias in the Conv2d that is followed by a BatchNorm2d, to speed up training because the model was slow to train on my machine. I referenced the docs on pytorch.org quite a bit to understand certain layers and parameters and what impact they have on training.
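As a brief aside on the bias trick: in training mode BatchNorm subtracts the per-channel batch mean, so any constant bias added by the preceding conv cancels out exactly. This small demonstration is not part of the project code, just a sketch of why the parameter is redundant:

```python
import torch
from torch import nn

# Sketch (not project code): a constant per-channel "bias" added before
# BatchNorm2d has no effect on the normalized output in training mode,
# because BN subtracts the per-channel batch mean.
torch.manual_seed(0)
x = torch.randn(8, 4, 16, 16)
bias = torch.randn(1, 4, 1, 1)      # stands in for a Conv2d bias term

bn_a = nn.BatchNorm2d(4).train()
bn_b = nn.BatchNorm2d(4).train()    # fresh BN with identical default affine params
out_plain = bn_a(x)
out_biased = bn_b(x + bias)
assert torch.allclose(out_plain, out_biased, atol=1e-4)
```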
I used and modified extract_patches_training.py and extract_patches_testing.py to create the training and testing sets of patches from the cell images. I also created an extract_patches_validation.py, which I used to build a validation set for tracking model progress, overfitting, and underfitting during training. The training set contained patches from cell images 0001.000, 0001.001, and 0001.002; the validation set contained patches from cell image 0001.004; and the testing set contained patches from cell images 0001.005 and 0001.006.
I experimented with different hyperparameters and model architectures over roughly 10 runs to find what worked best, tracking everything with tensorboard (those runs are documented in the runs folder). As mentioned above, I had some trouble with the model overfitting, so I also modified the RegistrationDatasetLoader in utils.py to accept a transform that includes randomized ColorJitter and randomized GaussianBlur. To apply the same random transformation to both input images, I had to manipulate torch's RNG state. I couldn't find this by searching directly on pytorch.org; it led me to https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/2 and then to https://github.com/pytorch/vision/issues/9, which has an example near the bottom of using get_rng_state and set_rng_state to do exactly that. I also tuned the dataloader parameters to speed up training by increasing the batch size and pinning memory. I relied on pytorch.org heavily while working through these settings and improving my models.
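The RNG-state trick from those links boils down to snapshotting torch's RNG before transforming the first image and rewinding it before transforming the second. A minimal sketch (the helper name is mine; the actual logic lives inside the modified RegistrationDatasetLoader):

```python
import torch

def apply_identical_transform(transform, img1, img2):
    # Snapshot the global RNG state so both calls see the same random draws
    state = torch.get_rng_state()
    out1 = transform(img1)
    torch.set_rng_state(state)  # rewind before transforming the second image
    out2 = transform(img2)
    return out1, out2
```

Any transform that draws its randomness from torch's global RNG (as the torchvision transforms do) will then produce identical parameters for both images.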
I trained for a full epoch and then evaluated on the validation set after each epoch. Below that you can see the MSELoss on the testing set, along with examples of the predicted x, y shift versus the actual x, y shift on images from both the testing and training datasets. The model does a decent job on the training set, but doesn't always generalize to the testing set as well as I would like: some offsets are close and the images look well aligned, while others in the testing set are fairly far off. This project was very interesting, and I definitely learned a lot about how best to train and test a custom model.
import time
import torch
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from utils import RegistrationDatasetLoader
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
class ImageRegModel(torch.nn.Module):
    def __init__(self):
        super(ImageRegModel, self).__init__()
        rn18 = resnet18(weights='DEFAULT')
        # resnet18 up to and including layer3
        self.backbone = torch.nn.Sequential(*list(rn18.children())[:-3])
        del rn18
        # Disable the Conv2d bias because the following BatchNorm2d cancels it
        self.regression = torch.nn.Sequential(
            torch.nn.Conv2d(512, 32, 3, 1, 1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.MaxPool2d(5),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Dropout(0.6),
            torch.nn.Linear(288, 2),
        )

    def forward(self, img1, img2):
        # Inputs have shape (N, C, H, W): batch, channels, height, width
        # Run the shared backbone on both images
        i1 = self.backbone(img1)
        i2 = self.backbone(img2)
        # Concatenate the feature maps along the channel dimension
        a = torch.cat([i1, i2], dim=1)
        # Regression head predicts (x_shift, y_shift)
        o = self.regression(a)
        return o
# Same model with the max pooling removed.
# This model does not perform as well.
class ImageRegModel2(torch.nn.Module):
    def __init__(self):
        super(ImageRegModel2, self).__init__()
        rn18 = resnet18(weights='DEFAULT')
        # resnet18 up to and including layer3
        self.backbone = torch.nn.Sequential(*list(rn18.children())[:-3])
        del rn18
        # Disable the Conv2d bias because the following BatchNorm2d cancels it
        self.regression = torch.nn.Sequential(
            torch.nn.Conv2d(512, 32, 3, 1, 1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Dropout(0.6),
            torch.nn.Linear(8192, 2),
        )

    def forward(self, img1, img2):
        # Inputs have shape (N, C, H, W): batch, channels, height, width
        # Run the shared backbone on both images
        i1 = self.backbone(img1)
        i2 = self.backbone(img2)
        # Concatenate the feature maps along the channel dimension
        a = torch.cat([i1, i2], dim=1)
        # Regression head predicts (x_shift, y_shift)
        o = self.regression(a)
        return o
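A quick back-of-the-envelope check of the Linear in_features in the two heads, assuming 256x256 input patches (the patch size is my assumption, not stated above): ResNet18 through layer3 downsamples by 16x, giving 16x16 feature maps, and MaxPool2d(5) with its default stride reduces 16x16 to 3x3.

```python
# Hypothetical sanity check of the flattened feature sizes,
# assuming 256x256 input patches (an assumption, not stated above).
def flattened_size(patch_hw, pool_kernel=None):
    feat = patch_hw // 16           # ResNet18 through layer3 downsamples by 16x
    if pool_kernel is not None:
        feat //= pool_kernel        # MaxPool2d(k) defaults to stride k
    return 32 * feat * feat         # 32 channels after the Conv2d above

print(flattened_size(256, pool_kernel=5))  # matches Linear(288, 2) in ImageRegModel
print(flattened_size(256))                 # matches Linear(8192, 2) in ImageRegModel2
```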
# Changes following pytorch.org tuning guides to speed up training,
# such as not calling .cpu() every iteration.
dtype = torch.float32  # data type used in training
# Note: despite the name, `cpu` holds the CUDA device when one is available
cpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64
def display_training_and_validation_loss(dir_path):
    # Since I was using tensorboard to track runs, I had to look up how to access its data.
    # That led me to some tensorflow documentation and then to here:
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/backend/event_processing/event_accumulator.py
    # EventAccumulator provides a way to access data from the run
    ea = EventAccumulator(dir_path)
    # From the link above, the Reload() method loads all the data for the run
    ea.Reload()
    # Now we can access the scalars written during the run and display them
    training_loss = ea.Scalars("training loss")
    validation_loss = ea.Scalars("validation loss")
    # Pull the step/value pairs out of each scalar series
    training_loss_step = [e.step for e in training_loss]
    training_loss_loss = [e.value for e in training_loss]
    validation_loss_step = [e.step for e in validation_loss]
    validation_loss_loss = [e.value for e in validation_loss]
    # Display loss curves
    plt.figure()
    plt.plot(validation_loss_step, validation_loss_loss, label='val loss')
    plt.plot(training_loss_step, training_loss_loss, label='train loss')
    plt.legend()
    plt.title('Training and validation loss')
    plt.xlabel('Step')
    plt.ylabel('Avg Loss')
    plt.show()
def checkTestingMse(model, dataloader):
    print('\n')
    print('**** Running model on the testing dataset ****')
    model.eval()
    running_loss = 0.0
    loss_function = nn.MSELoss()
    with torch.no_grad():  # no gradients needed for evaluation
        for t, temp in enumerate(dataloader):
            referenceImage = temp["ref_image"].to(device=cpu, dtype=dtype)
            inputImage = temp["inputImage"].to(device=cpu, dtype=dtype)
            x_shift = temp['x-shift']
            y_shift = temp['y-shift']
            # Run the image pair through the model
            out = model(referenceImage, inputImage)
            # Stack x, y shifts into the target tensor
            shift = torch.as_tensor(np.column_stack([x_shift, y_shift]), dtype=dtype, device=cpu)
            # Only compute the loss; no backprop at test time
            loss = loss_function(out, shift)
            running_loss += loss.item() * len(x_shift)
    print(f'Test total loss: {running_loss}')
    print(f'Test avg loss: {running_loss / len(dataloader.dataset)}')
    return running_loss
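Why the loop multiplies loss.item() by the batch size before summing: MSELoss with the default mean reduction averages over the batch, so re-weighting each batch by its length and dividing by the dataset size recovers the true per-sample average even when the final batch is ragged. A small numeric illustration (numpy stands in for torch here; the data is made up):

```python
import numpy as np

# Per-sample squared errors for a toy "dataset" of 10 predictions
per_sample = (np.arange(10, dtype=float) - 0.0) ** 2
# Split into batches of 4, 4, and a ragged final batch of 2
batches = [per_sample[:4], per_sample[4:8], per_sample[8:]]
# Mirror the loop above: accumulate batch_mean * batch_len, divide by N
running = sum(batch.mean() * len(batch) for batch in batches)
assert np.isclose(running / len(per_sample), per_sample.mean())
```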
def get_num_channels(img):
    # Grayscale images have no channel dimension
    if len(img.shape) < 3:
        return 1
    return img.shape[-1]

def convert_grayscale(image):
    # Luminance-style channel weights (sum to 1)
    weights = [0.3, 0.6, 0.1]
    channels = get_num_channels(image)
    if channels == 1:  # Already grayscale
        return image
    elif channels == 3:  # RGB
        return (image @ weights) / 255.0
    elif channels == 4:  # RGBA
        # Only use the first 3 channels
        return (image[:, :, :3] @ weights) / 255.0
    raise ValueError(f'Unexpected number of channels: {channels}')
def get_fitted_image(inp_img):
    # Crop the canvas to the bounding box of its nonzero pixels
    rows, cols = np.nonzero(inp_img > 0)
    # +1 because numpy slicing excludes the end index
    fitted_canvas = inp_img[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
    return fitted_canvas
def display_out(data, axs):
    f_img = convert_grayscale(data[0].permute(1, 2, 0).numpy())
    g_img = convert_grayscale(data[1].permute(1, 2, 0).numpy())
    axs[0].imshow(f_img, cmap="gray", origin='lower')
    axs[0].set_title("f image")
    axs[1].imshow(g_img, cmap="gray", origin='lower')
    axs[1].set_title("g image")
    # Place f in the center of a 3x canvas, then place g at the actual shift
    c, m, n = data[0].shape
    canvas = np.zeros((3 * m, 3 * n))
    canvas[m:2 * m, n:2 * n] = f_img
    x_shift, y_shift = data[2], data[3]
    canvas[m + x_shift:2 * m + x_shift, n + y_shift:2 * n + y_shift] = g_img
    axs[2].imshow(canvas, cmap="gray", origin='lower')
    axs[2].set_title("actual alignment")
    axs[2].set_xlabel(f'warped by ({data[2]},{data[3]})')
    axs[3].imshow(get_fitted_image(canvas), cmap="gray", origin='lower')
    axs[3].set_title("actual alignment (fitted)")
    axs[3].set_xlabel(f'warped by ({data[2]},{data[3]})')
    # Repeat with the predicted shift
    x_shift_pred, y_shift_pred = data[5].numpy()
    x_shift_pred, y_shift_pred = int(x_shift_pred), int(y_shift_pred)
    canvas = np.zeros((3 * m, 3 * n))
    canvas[m:2 * m, n:2 * n] = f_img
    canvas[m + x_shift_pred:2 * m + x_shift_pred, n + y_shift_pred:2 * n + y_shift_pred] = g_img
    axs[4].imshow(get_fitted_image(canvas), cmap="gray", origin='lower')
    axs[4].set_title("predicted alignment")
    axs[4].set_xlabel(f'warped by ({x_shift_pred},{y_shift_pred})')
    return
def display_output_images(model, dataloader, train=False):
    # Collect a few example pairs and predictions to display
    f_images = []
    g_images = []
    x_shift_arr = []
    y_shift_arr = []
    predictions = []
    mse_loss = []
    num_images = 4
    model.eval()
    loss_function = nn.MSELoss()
    for t, temp in enumerate(dataloader):
        if t >= num_images:
            break
        referenceImage = temp["ref_image"].to(device=cpu, dtype=dtype)
        inputImage = temp["inputImage"].to(device=cpu, dtype=dtype)
        x_shift = temp['x-shift']
        y_shift = temp['y-shift']
        # Only keep the first item of each batch
        referenceImage = referenceImage[:1]
        inputImage = inputImage[:1]
        x_shift = x_shift[:1]
        y_shift = y_shift[:1]
        # Run the image pair through the model
        out = model(referenceImage, inputImage)
        # Stack x, y shifts into the target tensor
        shift = torch.as_tensor(np.column_stack([x_shift, y_shift]), dtype=dtype, device=cpu)
        # Only compute the loss; no optimization here
        loss = loss_function(out, shift)
        f_images.append(referenceImage.cpu().detach().squeeze())
        g_images.append(inputImage.cpu().detach().squeeze())
        x_shift_arr.append(x_shift.squeeze())
        y_shift_arr.append(y_shift.squeeze())
        mse_loss.append(loss.cpu().detach())
        predictions.append(out.cpu().detach().squeeze())
    rows = num_images
    cols = 5
    fig, axs = plt.subplots(rows, cols, figsize=(16, 14))
    for i, ax in enumerate(axs):
        display_out([f_images[i], g_images[i], x_shift_arr[i], y_shift_arr[i], mse_loss[i], predictions[i]], ax)
    dataloader_type = 'testing' if not train else 'training'
    plt.suptitle(f'Output {dataloader_type} Examples')
    plt.tight_layout()
    plt.savefig(f'model-predictions-{dataloader_type}-examples.png')
def trainingLoop(model, optimizer, nepochs, train_dataloader, val_dataloader, test_dataloader):
    curr_time = int(time.time())
    run_dir = 'runs/'
    run_name = f'run_{curr_time}'
    writer = SummaryWriter(run_dir + run_name)
    print(f'**** Starting {run_name} ****')
    total_records = len(train_dataloader.dataset)
    total_val_records = len(val_dataloader.dataset)
    model = model.to(device=cpu)
    loss_function = nn.MSELoss()
    current_step = 0
    for ep in range(nepochs):
        print("Epoch", ep)
        model.train()
        running_loss = 0.0
        for t, temp in enumerate(train_dataloader):
            referenceImage = temp["ref_image"].to(device=cpu, dtype=dtype)
            inputImage = temp["inputImage"].to(device=cpu, dtype=dtype)
            x_shift = temp['x-shift']
            y_shift = temp['y-shift']
            # Zero out gradients from the previous step
            optimizer.zero_grad()
            # Run the image pair through the model
            out = model(referenceImage, inputImage)
            # Stack x, y shifts into the target tensor
            shift = torch.as_tensor(np.column_stack([x_shift, y_shift]), dtype=dtype, device=cpu)
            # Calculate loss, backprop, optimize
            loss = loss_function(out, shift)
            loss.backward()
            optimizer.step()
            # Track running loss
            running_loss += loss.item() * len(x_shift)
            if t % 20 == 19:
                # Print the current training loss and write it to tensorboard for tracking.
                # I followed this and other tutorials on pytorch.org to figure out how to
                # use tensorboard with torch:
                # https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
                batch_loss, current = loss.item() * len(x_shift), (t + 1) * batch_size
                print(f"training loss: {batch_loss:>7f} [{current:>5d}/{total_records:>5d}]")
                current_step = (ep * total_records) + ((t + 1) * batch_size)
                writer.add_scalar('training loss', running_loss / ((t + 1) * batch_size), current_step)
        print(f"training avg loss: {running_loss / total_records:>7f} [{total_records:>5d}/{total_records:>5d}]")
        current_step = (ep + 1) * total_records
        writer.add_scalar('training loss', running_loss / total_records, current_step)
        # Validation step
        model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for t, temp in enumerate(val_dataloader):
                referenceImage = temp["ref_image"].to(device=cpu, dtype=dtype)
                inputImage = temp["inputImage"].to(device=cpu, dtype=dtype)
                x_shift = temp['x-shift']
                y_shift = temp['y-shift']
                # Run the image pair through the model
                out = model(referenceImage, inputImage)
                # Stack x, y shifts into the target tensor
                shift = torch.as_tensor(np.column_stack([x_shift, y_shift]), dtype=dtype, device=cpu)
                # Only compute the loss; no backprop during validation
                loss = loss_function(out, shift)
                running_loss += loss.item() * len(x_shift)
        print(f"validation avg loss: {running_loss / total_val_records:>7f}\n")
        writer.add_scalar('validation loss', running_loss / total_val_records, current_step)
        print('saving model')
        torch.save(model, f'{run_dir}{run_name}/model_epoch{ep}_{running_loss / total_val_records:>7f}.pth')
    # Display the loss curves
    display_training_and_validation_loss(run_dir + run_name + '/')
    # Run the model on the testing dataset
    checkTestingMse(model, test_dataloader)
    # Display outputs for the test dataloader
    display_output_images(model, test_dataloader)
    # Display outputs for the train dataloader for fun
    display_output_images(model, train_dataloader, True)
    return True
# Additional transformations to reduce overfitting and improve generalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ColorJitter(brightness=.5, contrast=.5, hue=.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
])
TrainingSet = RegistrationDatasetLoader(csv_file="TrainingDatasetClassification.csv", transform=transform)
TestingSet = RegistrationDatasetLoader(csv_file="TestingDatasetClassification.csv", transform=transform)
ValidationSet = RegistrationDatasetLoader(csv_file="ValidationDatasetClassification.csv", transform=transform)
# DataLoader is a pytorch class for iterating over a dataset
dataloader_train = DataLoader(TrainingSet, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=True)
dataloader_test = DataLoader(TestingSet, batch_size=1)
dataloader_val = DataLoader(ValidationSet, batch_size=32)
Below are the different training runs with various hyperparameter settings, including runs of the second model.
This model overfit on the training dataset. You can see that the training curve kept dropping while the validation curve dropped and then climbed back up as training continued. I found that all of these models tended to overfit relatively quickly, so I made changes to combat that. You can see from the displayed results below that it doesn't do a very good job on the testing set and does an okay job on the training set.
learningRate = 1e-3 # Set this value between 1e-2 and 1e-4
weightDecay = 1e-3 # Set this value between 1e-2 and 1e-4
momentum = 0.9 # Adding momentum to potentially improve training using SGD
epochs = 5
model = ImageRegModel()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=momentum)
res = trainingLoop(model, optimizer, epochs, dataloader_train, dataloader_val, dataloader_test)
**** Starting run_1699667033 ****
Epoch 0
training loss: 92676.781250 [ 1280/11582]
training loss: 89695.687500 [ 2560/11582]
training loss: 79758.968750 [ 3840/11582]
training loss: 84336.679688 [ 5120/11582]
training loss: 79116.648438 [ 6400/11582]
training loss: 63972.156250 [ 7680/11582]
training loss: 51610.710938 [ 8960/11582]
training loss: 32668.437500 [10240/11582]
training loss: 36707.914062 [11520/11582]
training avg loss: 1069.949843 [11582/11582]
validation avg loss: 998.301141
saving model
Epoch 1
training loss: 25147.777344 [ 1280/11582]
training loss: 32159.830078 [ 2560/11582]
training loss: 28557.291016 [ 3840/11582]
training loss: 15403.156250 [ 5120/11582]
training loss: 22531.144531 [ 6400/11582]
training loss: 18563.750000 [ 7680/11582]
training loss: 14980.917969 [ 8960/11582]
training loss: 16294.951172 [10240/11582]
training loss: 17715.789062 [11520/11582]
training avg loss: 382.273082 [11582/11582]
validation avg loss: 895.559807
saving model
Epoch 2
training loss: 20515.986328 [ 1280/11582]
training loss: 17483.546875 [ 2560/11582]
training loss: 19725.820312 [ 3840/11582]
training loss: 17163.648438 [ 5120/11582]
training loss: 12853.043945 [ 6400/11582]
training loss: 13196.402344 [ 7680/11582]
training loss: 10533.628906 [ 8960/11582]
training loss: 11656.175781 [10240/11582]
training loss: 17671.285156 [11520/11582]
training avg loss: 241.619570 [11582/11582]
validation avg loss: 565.775556
saving model
Epoch 3
training loss: 11929.737305 [ 1280/11582]
training loss: 20939.066406 [ 2560/11582]
training loss: 16877.222656 [ 3840/11582]
training loss: 14261.367188 [ 5120/11582]
training loss: 14600.246094 [ 6400/11582]
training loss: 10763.610352 [ 7680/11582]
training loss: 9552.773438 [ 8960/11582]
training loss: 10653.466797 [10240/11582]
training loss: 18036.388672 [11520/11582]
training avg loss: 196.873011 [11582/11582]
validation avg loss: 516.859718
saving model
Epoch 4
training loss: 10483.447266 [ 1280/11582]
training loss: 12363.435547 [ 2560/11582]
training loss: 9940.042969 [ 3840/11582]
training loss: 10010.521484 [ 5120/11582]
training loss: 7455.038086 [ 6400/11582]
training loss: 12795.957031 [ 7680/11582]
training loss: 13159.134766 [ 8960/11582]
training loss: 7974.393555 [10240/11582]
training loss: 10757.263672 [11520/11582]
training avg loss: 158.359383 [11582/11582]
validation avg loss: 643.760141
saving model
**** Running model on the testing dataset ****
Test total loss: 803689.1726624109
Test avg loss: 407.96404703675677
This model has a slightly lower learning rate and performs okay relative to the other models that were trained. It still clearly overfits on the training dataset.
learningRate = 5e-4 # Set this value between 1e-2 and 1e-4
weightDecay = 1e-3 # Set this value between 1e-2 and 1e-4
momentum = 0.9 # Adding momentum to potentially improve training using SGD
epochs = 5
model = ImageRegModel()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=momentum)
res = trainingLoop(model, optimizer, epochs, dataloader_train, dataloader_val, dataloader_test)
**** Starting run_1699669004 ****
Epoch 0
training loss: 79456.203125 [ 1280/11582]
training loss: 42174.625000 [ 2560/11582]
training loss: 27218.302734 [ 3840/11582]
training loss: 23047.146484 [ 5120/11582]
training loss: 21589.144531 [ 6400/11582]
training loss: 14941.573242 [ 7680/11582]
training loss: 24966.425781 [ 8960/11582]
training loss: 16108.666992 [10240/11582]
training loss: 10198.508789 [11520/11582]
training avg loss: 509.066212 [11582/11582]
validation avg loss: 596.000524
saving model
Epoch 1
training loss: 10408.637695 [ 1280/11582]
training loss: 8390.769531 [ 2560/11582]
training loss: 7924.521484 [ 3840/11582]
training loss: 6618.127930 [ 5120/11582]
training loss: 6724.537598 [ 6400/11582]
training loss: 7050.188477 [ 7680/11582]
training loss: 5569.576660 [ 8960/11582]
training loss: 8048.165527 [10240/11582]
training loss: 6130.970703 [11520/11582]
training avg loss: 142.594085 [11582/11582]
validation avg loss: 655.230374
saving model
Epoch 2
training loss: 5644.655273 [ 1280/11582]
training loss: 7691.002930 [ 2560/11582]
training loss: 6768.528809 [ 3840/11582]
training loss: 4619.422363 [ 5120/11582]
training loss: 4248.875977 [ 6400/11582]
training loss: 5267.682129 [ 7680/11582]
training loss: 4574.267090 [ 8960/11582]
training loss: 4562.771484 [10240/11582]
training loss: 3913.964844 [11520/11582]
training avg loss: 77.952699 [11582/11582]
validation avg loss: 666.102900
saving model
Epoch 3
training loss: 5968.613770 [ 1280/11582]
training loss: 3990.166016 [ 2560/11582]
training loss: 5184.181152 [ 3840/11582]
training loss: 3012.883789 [ 5120/11582]
training loss: 4130.477539 [ 6400/11582]
training loss: 3266.033203 [ 7680/11582]
training loss: 4540.047852 [ 8960/11582]
training loss: 2861.265137 [10240/11582]
training loss: 5284.181152 [11520/11582]
training avg loss: 69.575277 [11582/11582]
validation avg loss: 721.667421
saving model
Epoch 4
training loss: 3087.774170 [ 1280/11582]
training loss: 3638.944824 [ 2560/11582]
training loss: 4175.141602 [ 3840/11582]
training loss: 3144.037598 [ 5120/11582]
training loss: 3674.964355 [ 6400/11582]
training loss: 2628.866699 [ 7680/11582]
training loss: 3304.237549 [ 8960/11582]
training loss: 3734.013916 [10240/11582]
training loss: 3269.593750 [11520/11582]
training avg loss: 61.360194 [11582/11582]
validation avg loss: 667.606290
saving model
**** Running model on the testing dataset ****
Test total loss: 1327409.846747905
Test avg loss: 673.8121049481751
This model has a lower learning rate and actually overfits more than any other model on the training set. You can see from the training output images that it does a great job of predicting the x, y shift, but the validation loss grows every single epoch of training and ends up the highest of all the models I trained. Unsurprisingly, it also has the highest loss on the testing dataset.
learningRate = 1e-4 # Set this value between 1e-2 and 1e-4
weightDecay = 1e-3 # Set this value between 1e-2 and 1e-4
momentum = 0.9 # Adding momentum to potentially improve training using SGD
epochs = 5
model = ImageRegModel()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=momentum)
res = trainingLoop(model, optimizer, epochs, dataloader_train, dataloader_val, dataloader_test)
**** Starting run_1699670190 ****
Epoch 0
training loss: 68465.617188 [ 1280/11582]
training loss: 77228.250000 [ 2560/11582]
training loss: 26948.089844 [ 3840/11582]
training loss: 16732.226562 [ 5120/11582]
training loss: 20831.316406 [ 6400/11582]
training loss: 13769.865234 [ 7680/11582]
training loss: 15971.530273 [ 8960/11582]
training loss: 12591.771484 [10240/11582]
training loss: 13210.708984 [11520/11582]
training avg loss: 513.842127 [11582/11582]
validation avg loss: 1321.640993
saving model
Epoch 1
training loss: 7669.704102 [ 1280/11582]
training loss: 8875.731445 [ 2560/11582]
training loss: 7454.131836 [ 3840/11582]
training loss: 6147.046387 [ 5120/11582]
training loss: 5487.067383 [ 6400/11582]
training loss: 4976.541504 [ 7680/11582]
training loss: 4871.927734 [ 8960/11582]
training loss: 4083.368164 [10240/11582]
training loss: 4412.645996 [11520/11582]
training avg loss: 100.845307 [11582/11582]
validation avg loss: 1492.342193
saving model
Epoch 2
training loss: 3733.685303 [ 1280/11582]
training loss: 3655.067871 [ 2560/11582]
training loss: 5009.525879 [ 3840/11582]
training loss: 4832.305664 [ 5120/11582]
training loss: 3916.507568 [ 6400/11582]
training loss: 3347.320557 [ 7680/11582]
training loss: 2862.141357 [ 8960/11582]
training loss: 4150.048828 [10240/11582]
training loss: 3157.463379 [11520/11582]
training avg loss: 61.470535 [11582/11582]
validation avg loss: 1500.884460
saving model
Epoch 3
training loss: 3144.101074 [ 1280/11582]
training loss: 5238.327148 [ 2560/11582]
training loss: 2858.953125 [ 3840/11582]
training loss: 4514.028320 [ 5120/11582]
training loss: 3152.369629 [ 6400/11582]
training loss: 4470.482422 [ 7680/11582]
training loss: 4221.897461 [ 8960/11582]
training loss: 2304.720947 [10240/11582]
training loss: 2743.229492 [11520/11582]
training avg loss: 52.448796 [11582/11582]
validation avg loss: 1597.215967
saving model
Epoch 4
training loss: 2613.798096 [ 1280/11582]
training loss: 2060.591309 [ 2560/11582]
training loss: 2915.719727 [ 3840/11582]
training loss: 2859.226562 [ 5120/11582]
training loss: 2621.983887 [ 6400/11582]
training loss: 3310.562500 [ 7680/11582]
training loss: 2845.667969 [ 8960/11582]
training loss: 3542.064453 [10240/11582]
training loss: 4319.824219 [11520/11582]
training avg loss: 46.849657 [11582/11582]
validation avg loss: 1623.546238
saving model
**** Running model on the testing dataset ****
Test total loss: 3005546.7939618826
Test avg loss: 1525.658271046641
This is my best performing model on both the validation and testing datasets. It has the same hyperparameters as the first model trained above; only the randomness differs, which has a big impact on how well the model performs.
In the images below, you can see that some x, y shift predictions are somewhat close, but many are off. The model still does not generalize as well as I would like, although it gets closer to the correct values on the test set than the other trained models.
Below are also examples from the training set with the predicted x, y shift from the trained model. You can see that it does a pretty good job of predicting the x, y shift. Besides comparing the numbers, it is interesting to look at the corners to see how well they match up with the corners of the actual aligned images. If I let the model train for additional epochs, the performance on the training images would likely improve, but it would potentially overfit more on the training set and generalize worse.
learningRate = 1e-3 # Set this value between 1e-2 and 1e-4
weightDecay = 1e-3 # Set this value between 1e-2 and 1e-4
momentum = 0.9 # Adding momentum to potentially improve training using SGD
epochs = 5
model = ImageRegModel()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=momentum)
res = trainingLoop(model, optimizer, epochs, dataloader_train, dataloader_val, dataloader_test)
**** Starting run_1699671338 ****
Epoch 0
training loss: 88116.187500 [ 1280/11582]
training loss: 77907.648438 [ 2560/11582]
training loss: 70020.101562 [ 3840/11582]
training loss: 58398.656250 [ 5120/11582]
training loss: 48831.027344 [ 6400/11582]
training loss: 47674.550781 [ 7680/11582]
training loss: 44285.687500 [ 8960/11582]
training loss: 34899.628906 [10240/11582]
training loss: 33032.968750 [11520/11582]
training avg loss: 900.148294 [11582/11582]
validation avg loss: 452.453043
saving model
Epoch 1
training loss: 31167.246094 [ 1280/11582]
training loss: 28969.957031 [ 2560/11582]
training loss: 31288.349609 [ 3840/11582]
training loss: 35778.121094 [ 5120/11582]
training loss: 25937.222656 [ 6400/11582]
training loss: 32856.765625 [ 7680/11582]
training loss: 24748.031250 [ 8960/11582]
training loss: 28179.062500 [10240/11582]
training loss: 23700.929688 [11520/11582]
training avg loss: 472.206713 [11582/11582]
validation avg loss: 393.551325
saving model
Epoch 2
training loss: 29874.066406 [ 1280/11582]
training loss: 22646.894531 [ 2560/11582]
training loss: 20878.164062 [ 3840/11582]
training loss: 16952.248047 [ 5120/11582]
training loss: 26861.882812 [ 6400/11582]
training loss: 29055.068359 [ 7680/11582]
training loss: 21642.962891 [ 8960/11582]
training loss: 17785.140625 [10240/11582]
training loss: 19964.476562 [11520/11582]
training avg loss: 341.459562 [11582/11582]
validation avg loss: 374.387541
saving model
Epoch 3
training loss: 17834.658203 [ 1280/11582]
training loss: 22931.933594 [ 2560/11582]
training loss: 18792.798828 [ 3840/11582]
training loss: 17660.791016 [ 5120/11582]
training loss: 14740.080078 [ 6400/11582]
training loss: 25890.175781 [ 7680/11582]
training loss: 17117.923828 [ 8960/11582]
training loss: 16177.603516 [10240/11582]
training loss: 12941.824219 [11520/11582]
training avg loss: 276.500388 [11582/11582]
validation avg loss: 326.446480
saving model
Epoch 4
training loss: 12682.767578 [ 1280/11582]
training loss: 13808.327148 [ 2560/11582]
training loss: 15159.840820 [ 3840/11582]
training loss: 15742.969727 [ 5120/11582]
training loss: 14346.420898 [ 6400/11582]
training loss: 18393.556641 [ 7680/11582]
training loss: 12164.410156 [ 8960/11582]
training loss: 15992.047852 [10240/11582]
training loss: 10414.591797 [11520/11582]
training avg loss: 209.444935 [11582/11582]
validation avg loss: 447.030639
saving model
**** Running model on the testing dataset ****
Test total loss: 767586.9047348499
Test avg loss: 389.63802270804564
This is a training run of the second model, which does not have the MaxPool2d, and you can see that it barely trains at all, even on the training set. It is the worst performing model of those I trained.
learningRate = 1e-3 # Set this value between 1e-2 and 1e-4
weightDecay = 1e-3 # Set this value between 1e-2 and 1e-4
momentum = 0.9 # Adding momentum to potentially improve training using SGD
epochs = 5
model = ImageRegModel2()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=momentum)
res = trainingLoop(model, optimizer, epochs, dataloader_train, dataloader_val, dataloader_test)
**** Starting run_1699672820 ****
Epoch 0
training loss: 75970.250000 [ 1280/11582]
training loss: 81473.421875 [ 2560/11582]
training loss: 81341.179688 [ 3840/11582]
training loss: 89989.671875 [ 5120/11582]
training loss: 71742.703125 [ 6400/11582]
training loss: 79673.835938 [ 7680/11582]
training loss: 89097.187500 [ 8960/11582]
training loss: 82719.093750 [10240/11582]
training loss: 74823.351562 [11520/11582]
training avg loss: 1250.723136 [11582/11582]
validation avg loss: 1254.765407
saving model
Epoch 1
training loss: 87450.101562 [ 1280/11582]
training loss: 85882.921875 [ 2560/11582]
training loss: 72148.984375 [ 3840/11582]
training loss: 79987.851562 [ 5120/11582]
training loss: 77356.343750 [ 6400/11582]
training loss: 75754.203125 [ 7680/11582]
training loss: 70872.867188 [ 8960/11582]
training loss: 81310.187500 [10240/11582]
training loss: 87843.390625 [11520/11582]
training avg loss: 1241.940776 [11582/11582]
validation avg loss: 1255.173538
saving model
Epoch 2
training loss: 84613.171875 [ 1280/11582]
training loss: 78539.453125 [ 2560/11582]
training loss: 83705.765625 [ 3840/11582]
training loss: 84988.804688 [ 5120/11582]
training loss: 86753.968750 [ 6400/11582]
training loss: 81630.789062 [ 7680/11582]
training loss: 75176.546875 [ 8960/11582]
training loss: 89023.757812 [10240/11582]
training loss: 90950.835938 [11520/11582]
training avg loss: 1241.986279 [11582/11582]
validation avg loss: 1254.861140
saving model
Epoch 3
training loss: 82113.062500 [ 1280/11582]
training loss: 80315.546875 [ 2560/11582]
training loss: 81476.609375 [ 3840/11582]
training loss: 91044.203125 [ 5120/11582]
training loss: 71228.875000 [ 6400/11582]
training loss: 91647.906250 [ 7680/11582]
training loss: 79137.218750 [ 8960/11582]
training loss: 64700.242188 [10240/11582]
training loss: 68394.656250 [11520/11582]
training avg loss: 1242.063300 [11582/11582]
validation avg loss: 1254.716023
saving model
Epoch 4
training loss: 63781.859375 [ 1280/11582]
training loss: 77190.781250 [ 2560/11582]
training loss: 89644.351562 [ 3840/11582]
training loss: 84635.468750 [ 5120/11582]
training loss: 78793.054688 [ 6400/11582]
training loss: 86309.984375 [ 7680/11582]
training loss: 84505.750000 [ 8960/11582]
training loss: 80304.968750 [10240/11582]
training loss: 78859.703125 [11520/11582]
training avg loss: 1241.967747 [11582/11582]
validation avg loss: 1254.796866
saving model
**** Running model on the testing dataset ****
Test total loss: 2392728.727073461
Test avg loss: 1214.583110189574